package com.tanhua.spark.mongo;

import com.mongodb.MongoClient;
import com.mongodb.MongoCredential;
import com.mongodb.ServerAddress;
import com.mongodb.client.MongoCollection;
import com.mongodb.spark.MongoSpark;
import org.apache.commons.lang3.RandomUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.bson.Document;
import org.bson.types.ObjectId;
import org.joda.time.DateTime;
import scala.Tuple2;

import java.io.IOException;
import java.io.InputStream;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

/**
 * Scheduled Spark job that computes user-to-user recommendations.
 *
 * <p>It loads user profiles from MySQL via JDBC, scores every pair of users
 * (age / gender / city / education similarity), trains an ALS model through
 * {@code MLlibRecommend}, and writes the top-N recommendations per user into
 * a MongoDB collection (dropping the previous contents first).
 *
 * <p>Configuration comes from environment variables, with
 * {@code app-user-recommend.properties} on the classpath as a fallback.
 */
public class SparkUserRecommend {

    // MongoDB connection settings (env vars win, properties file is the fallback).
    private static String MONGODB_HOST = System.getenv("MONGODB_HOST");
    private static String MONGODB_PORT = System.getenv("MONGODB_PORT");
    private static String MONGODB_USERNAME = System.getenv("MONGODB_USERNAME");
    private static String MONGODB_PASSWORD = System.getenv("MONGODB_PASSWORD");
    private final static String MONGODB_DATABASE = System.getenv("MONGODB_DATABASE") == null ? "tanhua" : System.getenv("MONGODB_DATABASE");
    private final static String MONGODB_COLLECTION = System.getenv("MONGODB_COLLECTION") == null ? "recommend_user" : System.getenv("MONGODB_COLLECTION");

    // JDBC (MySQL) connection settings for the user-profile source table.
    private static String JDBC_URL = System.getenv("JDBC_URL");
    private static String JDBC_DRIVER = System.getenv("JDBC_DRIVER");
    private static String JDBC_USER = System.getenv("JDBC_USER");
    private static String JDBC_PASSWORD = System.getenv("JDBC_PASSWORD");
    private final static String JDBC_TABLE = System.getenv("JDBC_TABLE") == null ? "tb_user_info" : System.getenv("JDBC_TABLE");

    // Interval (minutes) between scheduled runs of execute().
    private final static Integer SCHEDULE_PERIOD = System.getenv("SCHEDULE_PERIOD") == null ? 10 : Integer.valueOf(System.getenv("SCHEDULE_PERIOD"));

    // Number of recommendations to persist per user.
    private final static Integer RECOMMEND_COUNT = System.getenv("RECOMMEND_COUNT") == null ? 50 : Integer.valueOf(System.getenv("RECOMMEND_COUNT"));

    static {
        // Load the external configuration file app-user-recommend.properties as a
        // fallback for any setting that was not supplied via environment variables.
        // try-with-resources closes the stream (the original leaked it), and the
        // null check avoids an NPE when the file is missing from the classpath.
        try (InputStream inputStream = SparkUserRecommend.class.getClassLoader()
                .getResourceAsStream("app-user-recommend.properties")) {
            if (inputStream != null) {
                Properties properties = new Properties();
                properties.load(inputStream);

                MONGODB_HOST = MONGODB_HOST == null ? properties.getProperty("mongodb.host") : MONGODB_HOST;
                MONGODB_PORT = MONGODB_PORT == null ? properties.getProperty("mongodb.port") : MONGODB_PORT;
                MONGODB_USERNAME = MONGODB_USERNAME == null ? properties.getProperty("mongodb.username") : MONGODB_USERNAME;
                MONGODB_PASSWORD = MONGODB_PASSWORD == null ? properties.getProperty("mongodb.password") : MONGODB_PASSWORD;

                JDBC_URL = JDBC_URL == null ? properties.getProperty("jdbc.url") : JDBC_URL;
                JDBC_DRIVER = JDBC_DRIVER == null ? properties.getProperty("jdbc.driver-class-name") : JDBC_DRIVER;
                JDBC_USER = JDBC_USER == null ? properties.getProperty("jdbc.username") : JDBC_USER;
                JDBC_PASSWORD = JDBC_PASSWORD == null ? properties.getProperty("jdbc.password") : JDBC_PASSWORD;
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Entry point: schedules {@link #execute()} to run immediately and then
     * every {@code SCHEDULE_PERIOD} minutes on a single-threaded scheduler.
     * Exceptions are caught inside the task so one failed run does not cancel
     * the schedule (scheduleAtFixedRate stops on an uncaught throwable).
     */
    public static void main(String[] args) {
        Runnable runnable = () -> {
            try {
                execute();
            } catch (Exception e) {
                e.printStackTrace();
            }
        };
        ScheduledExecutorService service = Executors
                .newSingleThreadScheduledExecutor();
        // Second argument is the initial delay, third is the period between runs.
        service.scheduleAtFixedRate(runnable, 0, SCHEDULE_PERIOD, TimeUnit.MINUTES);
    }

    /**
     * Runs one full recommendation cycle: load users from MySQL, score every
     * user pair, train the best ALS model, clear the MongoDB collection and
     * write the fresh top-{@code RECOMMEND_COUNT} recommendations per user.
     *
     * @throws Exception if model training or any I/O step fails; the caller
     *                   (the scheduled task in {@link #main}) logs and continues
     */
    public static void execute() throws Exception {
        // Build the Spark configuration, including the MongoDB output target
        // used later by MongoSpark.save(...).
        SparkConf sparkConf = new SparkConf()
                .setAppName("SparkUserRecommend")
                .setMaster("local[*]")
                .set("spark.driver.host", "localhost")
                .set("spark.mongodb.output.uri", "mongodb://" + MONGODB_USERNAME + ":" + MONGODB_PASSWORD + "@" + MONGODB_HOST + ":" + MONGODB_PORT + "/admin?readPreference=primaryPreferred")
                .set("spark.mongodb.output.database", MONGODB_DATABASE)
                .set("spark.mongodb.output.collection", MONGODB_COLLECTION);

        // getOrCreate reuses the session across scheduled runs.
        SparkSession sparkSession = SparkSession.builder().config(sparkConf).getOrCreate();

        // JDBC connection properties for the MySQL user table.
        Properties connectionProperties = new Properties();
        connectionProperties.put("driver", JDBC_DRIVER);
        connectionProperties.put("user", JDBC_USER);
        connectionProperties.put("password", JDBC_PASSWORD);

        JavaRDD<Row> userInfoRdd = sparkSession.read().jdbc(JDBC_URL, JDBC_TABLE, connectionProperties).toJavaRDD();

        // All user ids (column 1 of the table). NOTE(review): column indices
        // 1/5/6/7/8 below assume the tb_user_info column order — verify against
        // the table schema if it changes.
        List<Long> userIds = userInfoRdd.map(v -> v.getLong(1)).collect();

        // Cartesian product: every user paired with every other user.
        JavaPairRDD<Row, Row> cartesian = userInfoRdd.cartesian(userInfoRdd);

        // Score the similarity of each user pair and key it into one of 10
        // pseudo-random buckets (the random component only spreads pairs across
        // partitions; the bucket key carries no semantic meaning).
        JavaPairRDD<Long, Rating> javaPairRDD = cartesian.mapToPair(row -> {
            Row row1 = row._1();
            Row row2 = row._2();

            Long userId1 = row1.getLong(1);
            Long userId2 = row2.getLong(1);

            Long key = userId1 + userId2 + RandomUtils.nextLong();

            // A user compared with themselves scores 0.
            if (userId1.longValue() == userId2.longValue()) {
                return new Tuple2<>(key % 10, new Rating(userId1.intValue(), userId2.intValue(), 0));
            }

            double score = 0;

            // Age difference: closer ages score higher.
            int ageDiff = Math.abs(row1.getInt(6) - row2.getInt(6));
            if (ageDiff <= 2) {
                score += 30;
            } else if (ageDiff >= 3 && ageDiff <= 5) {
                score += 20;
            } else if (ageDiff > 5 && ageDiff <= 10) {
                score += 10;
            }

            // Opposite gender scores higher.
            if (row1.getInt(5) != row2.getInt(5)) {
                score += 30;
            }

            // Same city (prefix before '-') scores higher.
            String city1 = StringUtils.substringBefore(row1.getString(8), "-");
            String city2 = StringUtils.substringBefore(row2.getString(8), "-");
            if (StringUtils.equals(city1, city2)) {
                score += 20;
            }

            // Same education level scores higher.
            String edu1 = row1.getString(7);
            String edu2 = row2.getString(7);
            if (StringUtils.equals(edu1, edu2)) {
                score += 20;
            }

            Rating rating = new Rating(userId1.intValue(), userId2.intValue(), score);
            return new Tuple2<>(key % 10, rating);

        });

        // Train and select the best ALS recommendation model.
        MLlibRecommend mLlibRecommend = new MLlibRecommend();
        MatrixFactorizationModel bestModel = mLlibRecommend.bestModel(javaPairRDD);

        JavaSparkContext jsc = new JavaSparkContext(sparkSession.sparkContext());

        // Drop the previous recommendation data before writing the new batch.
        clearMongoDBData();

        for (Long userId : userIds) {
            Rating[] ratings = bestModel.recommendProducts(userId.intValue(), RECOMMEND_COUNT);
            JavaRDD<Document> documentJavaRDD = jsc.parallelize(Arrays.asList(ratings)).filter(rating -> rating.user() != rating.product()).map(v1 -> {
                Document document = new Document();

                document.put("_id", ObjectId.get());
                // The recommended user id goes into "userId", the receiving
                // user into "toUserId" — this matches the recommend_user schema.
                document.put("userId", v1.product());
                document.put("toUserId", v1.user());
                // Score truncated to 2 decimal places. BigDecimal.valueOf avoids
                // the inexact new BigDecimal(double) constructor, and
                // RoundingMode.DOWN replaces the deprecated ROUND_DOWN constant.
                double score = BigDecimal.valueOf(v1.rating()).setScale(2, RoundingMode.DOWN).doubleValue();
                document.put("score", score);
                document.put("date", new DateTime().toString("yyyy/MM/dd"));

                return document;
            });

            MongoSpark.save(documentJavaRDD);

        }

    }

    /**
     * Drops the target MongoDB collection so each run starts from a clean slate.
     * The client is closed via try-with-resources even if drop() throws
     * (the original leaked the connection on failure).
     */
    private static void clearMongoDBData() {
        List<ServerAddress> adds = new ArrayList<>();
        // ServerAddress takes the server host and port.
        ServerAddress serverAddress = new ServerAddress(MONGODB_HOST, Integer.valueOf(MONGODB_PORT));
        adds.add(serverAddress);

        List<MongoCredential> credentials = new ArrayList<>();
        // createScramSha1Credential takes username, auth database and password.
        MongoCredential mongoCredential = MongoCredential.createScramSha1Credential(MONGODB_USERNAME,
                "admin",
                MONGODB_PASSWORD.toCharArray());
        credentials.add(mongoCredential);

        // MongoClient implements Closeable, so try-with-resources guarantees
        // the connection is released.
        try (MongoClient mongoClient = new MongoClient(adds, credentials)) {
            MongoCollection<Document> collection = mongoClient.getDatabase(MONGODB_DATABASE).getCollection(MONGODB_COLLECTION);
            collection.drop();
        }
    }
}
